{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### INTERVAL RNA-seq sample and gene QC for misexpression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import pysam\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "wkdir = \"/lustre/scratch126/humgen/projects/interval_rna/interval_rna_seq/thomasVDS/misexpression_v3/\"\n",
    "wkdir_path = Path(wkdir)\n",
    "\n",
    "# inputs \n",
    "count_matrix_path = \"/lustre/scratch126/humgen/projects/interval_rna/interval_rna_seq_n5188_v97/results/combined/5591-star-fc-genecounts.txt\"\n",
    "rna_id_pass_qc_path = \"/lustre/scratch126/humgen/projects/interval_rna/interval_rna_seq/thomasVDS/misexpression/1_express_quant/sample_qc/pass_qc/rna_id_pass_qc.tsv\"\n",
    "# input gtf must be collapsed, sorted and tabix indexed \n",
    "gencode_gtf_path = wkdir_path.joinpath(\"reference/gencode/gencode.v31.annotation.collapsed.sorted.gtf.gz\")\n",
    "gtex_eqtl_dir = wkdir_path.joinpath(\"reference/gtex/eqtl/GTEx_Analysis_v8_eQTL_expression_matrices/\")\n",
    "out_dir = wkdir_path.joinpath(\"1_rna_seq_qc\")\n",
    "\n",
    "# variables \n",
    "aberrant_multiple = 5\n",
    "tpm_fract_cutoff = 0.1\n",
    "smpl_fraction = 0.05"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# constants \n",
    "AUTOSOMES = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', \n",
    "             'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', \n",
    "             'chr18', 'chr19', 'chr20', 'chr21', 'chr22']\n",
    "GENE_TYPES = [\"lncRNA\", \"protein_coding\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of samples in feature count matrix: 5488\n",
      "Number of samples after removing samples ending in '_1' and '_2': 4778\n",
      "Checking swaps ...\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "Done.\n",
      "Number of samples passing sample QC: 4731\n"
     ]
    }
   ],
   "source": [
    "# check output directory exists, if not make directory \n",
    "out_dir_path = Path(out_dir)\n",
    "out_dir_path.mkdir(parents=True, exist_ok=True)\n",
    "# read in the count matrix\n",
    "count_matrix_df = pd.read_csv(count_matrix_path, sep=\"\\t\").set_index(\"ENSEMBL_ID\")\n",
    "\n",
    "### sample QC \n",
    "\n",
    "# keep samples with 2x reads and remove 1x samples from batch 1 and batch 15\n",
    "rna_id_list = count_matrix_df.columns.tolist()\n",
    "print(f\"Number of samples in feature count matrix: {len(rna_id_list)}\")\n",
    "rna_id_to_keep = [rna_id for rna_id in rna_id_list if not rna_id.endswith(\"_1\") if not rna_id.endswith(\"_2\")]\n",
    "\n",
    "# remove columns with run1 and run2 sample IDs \n",
    "count_matrix_dplcts_rmvd_df = count_matrix_df[rna_id_to_keep]\n",
    "print(f\"Number of samples after removing samples ending in '_1' and '_2': {count_matrix_dplcts_rmvd_df.shape[1]}\")\n",
    "\n",
    "# swap samples \n",
    "swap_samples = [\"INT_RNA7879032\", \"INT_RNA7879033\", \"INT_RNA7960192\",\n",
    "                \"INT_RNA7960193\", \"INT_RNA7709692\", \"INT_RNA7709693\",\n",
    "                \"INT_RNA7710161\", \"INT_RNA7710162\", \"INT_RNA7710163\", \n",
    "                \"INT_RNA7710164\"]\n",
    "\n",
    "count_matrix_dplcts_rmvd_swapped_df = count_matrix_dplcts_rmvd_df.copy()\n",
    "for i in range(0, int(len(swap_samples)/2)):\n",
    "    swap_1 = swap_samples[2*i]\n",
    "    swap_2 = swap_samples[2*i + 1]\n",
    "    a, b = count_matrix_dplcts_rmvd_df.loc[:, swap_1], count_matrix_dplcts_rmvd_df.loc[:, swap_2]\n",
    "    count_matrix_dplcts_rmvd_swapped_df.loc[:, swap_2], count_matrix_dplcts_rmvd_swapped_df.loc[:, swap_1] = a,b\n",
    "\n",
    "# check swaps\n",
    "print(\"Checking swaps ...\")\n",
    "for i in range(0,int(len(swap_samples)/2)):\n",
    "    swap_1= swap_samples[i*2]\n",
    "    swap_2 = swap_samples[i*2+1]\n",
    "    print(count_matrix_dplcts_rmvd_swapped_df[swap_1].tolist() == count_matrix_dplcts_rmvd_df[swap_2].tolist())\n",
    "print(\"Done.\")\n",
    "# subset to samples passing QC \n",
    "rna_id_pass_qc = pd.read_csv(rna_id_pass_qc_path, sep=\"\\t\", header=None)[0].tolist()\n",
    "smpl_number = len(rna_id_pass_qc)\n",
    "print(f\"Number of samples passing sample QC: {smpl_number}\")\n",
    "\n",
    "# subset count matrix\n",
    "count_matrix_dplcts_rmvd_swapped_pass_qc_df = count_matrix_dplcts_rmvd_swapped_df[rna_id_pass_qc]\n",
    "\n",
    "# check \n",
    "if count_matrix_dplcts_rmvd_swapped_pass_qc_df.shape[1] != len(rna_id_pass_qc): \n",
    "    raise ValueError(\"Samples passing QC missing from count matrix.\")\n",
    "\n",
    "fc_mtx_pass_smpl_qc_df = count_matrix_dplcts_rmvd_swapped_pass_qc_df.reset_index().rename(columns={\"ENSEMBL_ID\": \"gene_id\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "QCing genes ...\n",
      "- Number of gene ids in count matrix: 60617\n",
      "- Total gene IDs in gencode file: 59752\n",
      "- Total unique gene IDs in gencode file: 59708\n",
      "- Total genes remaining: 59144\n"
     ]
    }
   ],
   "source": [
    "### Gene QC \n",
    "print(\"QCing genes ...\")\n",
    "gene_ids_in_fx_mtx = set(fc_mtx_pass_smpl_qc_df.gene_id)\n",
    "print(f\"- Number of gene ids in count matrix: {len(gene_ids_in_fx_mtx)}\")\n",
    "# get gene IDs from input gencode gtf \n",
    "gencode_gene_ids = [gtf.gene_id.split(\".\")[0] for gtf in pysam.TabixFile(str(gencode_gtf_path)).fetch(parser = pysam.asGTF()) if gtf.feature == \"gene\"]\n",
    "print(f\"- Total gene IDs in gencode file: {len(gencode_gene_ids)}\")\n",
    "# remove duplicate gene IDs\n",
    "gencode_gene_ids_uniq = set(gencode_gene_ids)\n",
    "print(f\"- Total unique gene IDs in gencode file: {len(gencode_gene_ids_uniq)}\")\n",
    "# get globin genes\n",
    "globin_gene_symbols = [\"HBA1\",\"HBA2\", \"HBB\", \"HBD\", \"HBE1\", \"HBG1\", \"HBG2\", \"HBM\", \"HBQ1\", \"HBZ\", \"MB\"]\n",
    "globin_gene_ids = [gtf.gene_id.split(\".\")[0] for gtf in pysam.TabixFile(str(gencode_gtf_path)).fetch(parser = pysam.asGTF()) if gtf.feature == \"gene\" and gtf.gene_name in globin_gene_symbols]\n",
    "if len(globin_gene_symbols) != len(globin_gene_ids): \n",
    "    raise ValueError(\"Missing globin gene IDs in gencode file.\") \n",
    "# get list of entries which contain \"rRNA\" in class\n",
    "rRNA_gene_ids = [gtf.gene_id.split(\".\")[0] for gtf in pysam.TabixFile(str(gencode_gtf_path)).fetch(parser = pysam.asGTF()) if \"rRNA\" in gtf.gene_type and gtf.feature == \"gene\"]\n",
    "rRNA_globin_gene_ids = set(globin_gene_ids).union(set(rRNA_gene_ids))\n",
    "# subset gene expression matrix \n",
    "genes_to_keep = gencode_gene_ids_uniq - rRNA_globin_gene_ids\n",
    "gene_number = len(genes_to_keep)\n",
    "fc_mtx_pass_smpl_gene_qc_df = fc_mtx_pass_smpl_qc_df[fc_mtx_pass_smpl_qc_df.gene_id.isin(genes_to_keep)].set_index(\"gene_id\")\n",
    "# check \n",
    "if fc_mtx_pass_smpl_gene_qc_df.shape[0] != len(genes_to_keep):\n",
    "    raise ValueError(\"Genes passing QC missing from count matrix.\")\n",
    "print(f\"- Total genes remaining: {gene_number}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write to file \n",
    "fc_out_dir = out_dir_path.joinpath(\"fc_mtx\")\n",
    "Path(fc_out_dir).mkdir(parents=True, exist_ok=True)\n",
    "fc_mtx_path = fc_out_dir.joinpath(f\"fc_mtx_{smpl_number}samples_{gene_number}genes.csv\")\n",
    "fc_mtx_pass_smpl_gene_qc_df.to_csv(fc_mtx_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculating TPM ...\n",
      "Done.\n"
     ]
    }
   ],
   "source": [
    "### calculate TPM \n",
    "print(f\"Calculating TPM ...\")\n",
    "# calculate total exon length \n",
    "gene_id_exon_len = {}\n",
    "for gtf in pysam.TabixFile(str(gencode_gtf_path)).fetch(parser = pysam.asGTF()):\n",
    "    if gtf.feature == \"exon\":\n",
    "        gene_id = gtf.gene_id\n",
    "        exon_len = int(gtf.end) - int(gtf.start) \n",
    "        # check if there are zero length exons\n",
    "        if exon_len == 0: \n",
    "            raise ValueError(f\"{gene_id} has zero length exon.\")\n",
    "        gene_id_exon_len[gene_id] = gene_id_exon_len.get(gene_id, 0) + exon_len \n",
    "\n",
    "# converted exon length dictionary to dataframe and clean \n",
    "gene_id_exon_len_df = pd.DataFrame.from_dict(gene_id_exon_len, orient=\"index\", columns=[\"length_exons\"]).reset_index(drop=False).rename(columns={\"index\": \"gene_id\"})\n",
    "gene_id_exon_len_no_vrsn_no_pary_df = gene_id_exon_len_df.copy()\n",
    "gene_id_exon_len_no_vrsn_no_pary_df = gene_id_exon_len_no_vrsn_no_pary_df[~gene_id_exon_len_no_vrsn_no_pary_df.gene_id.str.endswith(\"PAR_Y\")]\n",
    "gene_id_exon_len_no_vrsn_no_pary_df.gene_id = gene_id_exon_len_no_vrsn_no_pary_df.gene_id.str.split(\".\").str[0]\n",
    "# subset exon lengths to genes passing filters \n",
    "gene_id_exon_len_gene_pass_qc_df = gene_id_exon_len_no_vrsn_no_pary_df[gene_id_exon_len_no_vrsn_no_pary_df.gene_id.isin(genes_to_keep)].set_index(\"gene_id\").sort_index()\n",
    "\n",
    "# check gene order is identical \n",
    "if gene_id_exon_len_gene_pass_qc_df.index.tolist() != fc_mtx_pass_smpl_gene_qc_df.index.tolist(): \n",
    "    raise ValueError(\"Gene lists in count matrix and exon length matrix do not match.\")\n",
    "\n",
    "# compute RPK (reads per kilobase) - divide read count by exon length/1000\n",
    "rpk_df = fc_mtx_pass_smpl_gene_qc_df.div((gene_id_exon_len_gene_pass_qc_df.length_exons/1000), axis=0).T\n",
    "# divide by sum of RPK values to calculate TPM \n",
    "tpm_df = rpk_df.div(rpk_df.sum(axis=1)/1000000, axis=0)\n",
    "tpm_tp_df = tpm_df.T\n",
    "# check TPM sums to 1 million\n",
    "if not round(tpm_tp_df.sum(axis=0)).unique().tolist() == [1000000.0]:\n",
    "    raise ValueError(\"TPMs do not sum to 1,000,000.\")\n",
    "print(\"Done.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write to file \n",
    "tpm_out_dir = out_dir_path.joinpath(\"tpm_mtx\")\n",
    "Path(tpm_out_dir).mkdir(parents=True, exist_ok=True)\n",
    "tpm_mtx_path = tpm_out_dir.joinpath(f\"tpm_{smpl_number}samples_{gene_number}genes.csv\")\n",
    "tpm_tp_df.to_csv(tpm_mtx_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Removing global overexpression outlier samples ...\n",
      "Number of genes to count top expression events: 57555\n",
      "Total top expression events: 57555\n",
      "Differences in top expression events: 0\n"
     ]
    }
   ],
   "source": [
    "### aberrant sample QC \n",
    "print(\"Removing global overexpression outlier samples ...\")\n",
    "# remove genes that have TPM = 0 for all samples \n",
    "tpm_no_zeroes_df = tpm_tp_df.loc[~(tpm_tp_df==0).all(axis=1)]\n",
    "genes_with_max_express = tpm_no_zeroes_df.shape[0]\n",
    "print(f\"Number of genes to count top expression events: {genes_with_max_express}\")\n",
    "# count max sample for each gene \n",
    "max_gene_per_smpl_genes_dict = Counter(tpm_no_zeroes_df.idxmax(axis=1).tolist())\n",
    "max_gene_per_smpl_genes_df = pd.DataFrame.from_dict(max_gene_per_smpl_genes_dict, orient=\"index\", columns=[\"count\"]).reset_index(drop=False).rename(columns={\"index\":\"rna_id\"})\n",
    "# add samples with no top expression events \n",
    "smpl_with_no_top_event_dict = {rna_id: 0 for rna_id in rna_id_pass_qc if rna_id not in max_gene_per_smpl_genes_dict.keys()}\n",
    "smpl_with_no_top_event_df = pd.DataFrame.from_dict(smpl_with_no_top_event_dict, orient=\"index\", columns=[\"count\"]).reset_index(drop=False).rename(columns={\"index\":\"rna_id\"})\n",
    "max_gene_per_smpl_genes_all_df = pd.concat([max_gene_per_smpl_genes_df, smpl_with_no_top_event_df]).reset_index(drop=True)\n",
    "print(f\"Total top expression events: {max_gene_per_smpl_genes_all_df['count'].sum()}\")\n",
    "\n",
    "# if multiple columns have same max value idxmax function selects sample in nearest column\n",
    "# check that this does not lead to big differences in number of top expression events per sample\n",
    "# reverse dataframe \n",
    "tpm_no_zeroes_df_rev = tpm_no_zeroes_df.iloc[:, ::-1]\n",
    "max_gene_per_smpl_genes_rev_dict = Counter(tpm_no_zeroes_df_rev.idxmax(axis=1).tolist())\n",
    "max_gene_per_smpl_genes_rev_df = pd.DataFrame.from_dict(max_gene_per_smpl_genes_rev_dict, orient=\"index\", columns=[\"count_rev\"]).reset_index(drop=False).rename(columns={\"index\":\"rna_id\"})\n",
    "# add samples with no top expression events \n",
    "smpl_with_no_top_event_rev_dict = {rna_id: 0 for rna_id in rna_id_pass_qc if rna_id not in max_gene_per_smpl_genes_rev_dict.keys()}\n",
    "smpl_with_no_top_event_rev_df = pd.DataFrame.from_dict(smpl_with_no_top_event_rev_dict, orient=\"index\", columns=[\"count_rev\"]).reset_index(drop=False).rename(columns={\"index\":\"rna_id\"})\n",
    "max_gene_per_smpl_genes_all_rev_df = pd.concat([max_gene_per_smpl_genes_rev_df, smpl_with_no_top_event_rev_df]).reset_index(drop=True)\n",
    "max_gene_per_smpl_genes_all_rev_merged_df = pd.merge(max_gene_per_smpl_genes_all_df, max_gene_per_smpl_genes_all_rev_df, how=\"inner\", on=\"rna_id\")\n",
    "max_gene_per_smpl_genes_all_rev_merged_df[\"identical\"] = np.where(max_gene_per_smpl_genes_all_rev_merged_df[\"count\"] == max_gene_per_smpl_genes_all_rev_merged_df[\"count_rev\"], True, False)\n",
    "not_identical = max_gene_per_smpl_genes_all_rev_merged_df[~max_gene_per_smpl_genes_all_rev_merged_df[\"identical\"]].shape[0]\n",
    "print(f\"Differences in top expression events: {not_identical}\")\n",
    "exp_max_express = genes_with_max_express/smpl_number"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected number of top gene expression outliers: 12.165504121750159\n",
      "Top gene expression outliers cutoff: 60.827520608750795\n",
      "Number of samples failing QC: 163\n",
      "Percentage of samples failing QC: 3.4453603889241173\n",
      "Number of samples passing QC: 4568\n",
      "Done.\n"
     ]
    }
   ],
   "source": [
    "# aberrant sample QC metrics \n",
    "print(f\"Expected number of top gene expression outliers: {exp_max_express}\")\n",
    "max_express_cutoff = aberrant_multiple * exp_max_express\n",
    "print(f\"Top gene expression outliers cutoff: {max_express_cutoff}\")\n",
    "rna_id_fail_qc = max_gene_per_smpl_genes_all_df[max_gene_per_smpl_genes_all_df[\"count\"] > max_express_cutoff].rna_id.tolist()\n",
    "num_smpls_fail_qc = len(rna_id_fail_qc)\n",
    "print(f\"Number of samples failing QC: {num_smpls_fail_qc}\")\n",
    "print(f\"Percentage of samples failing QC: {(num_smpls_fail_qc/smpl_number)*100}\")\n",
    "# keep samples passing QC \n",
    "rna_ids_pass_qc = max_gene_per_smpl_genes_all_df[~max_gene_per_smpl_genes_all_df.rna_id.isin(rna_id_fail_qc)].rna_id.unique()\n",
    "num_rna_id_pass_qc = len(rna_ids_pass_qc)\n",
    "print(f\"Number of samples passing QC: {num_rna_id_pass_qc}\")\n",
    "max_gene_per_smpl_genes_all_df[\"pass_qc\"] = np.where(max_gene_per_smpl_genes_all_df.rna_id.isin(rna_ids_pass_qc), \"Pass\", \"Fail\")\n",
    "# subset expression matrix \n",
    "tpm_smpl_qc_df = tpm_tp_df[rna_ids_pass_qc]\n",
    "print(\"Done.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write to file \n",
    "aberrant_smpl_qc_dir = out_dir_path.joinpath(\"aberrant_smpl_qc\")\n",
    "Path(aberrant_smpl_qc_dir).mkdir(parents=True, exist_ok=True)\n",
    "aberrant_smpl_qc_path = aberrant_smpl_qc_dir.joinpath(\"aberrant_gene_count.csv\")\n",
    "max_gene_per_smpl_genes_all_srtd_df = max_gene_per_smpl_genes_all_df.sort_values(by=\"count\", ascending=False)\n",
    "max_gene_per_smpl_genes_all_srtd_df.to_csv(aberrant_smpl_qc_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write list of samples passing QC to file \n",
    "smpls_pass_qc_path = aberrant_smpl_qc_dir.joinpath(f\"smpls_pass_qc_{num_rna_id_pass_qc}.csv\")\n",
    "with open(smpls_pass_qc_path, 'w') as f_out: \n",
    "    for rna_id in rna_ids_pass_qc: \n",
    "        f_out.write(f\"{rna_id}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write TPM matrix with aberrant samples removed \n",
    "tpm_post_smpl_qc_path = tpm_out_dir.joinpath(f\"tpm_{num_rna_id_pass_qc}samples_{gene_number}genes_smpl_qc.csv\")\n",
    "tpm_smpl_qc_df.to_csv(tpm_post_smpl_qc_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write count matrix with aberrant samples removed \n",
    "fc_mtx_smpl_qc_path = fc_out_dir.joinpath(f\"fc_mtx_{num_rna_id_pass_qc}samples_{gene_number}genes.csv\")\n",
    "fc_mtx_gene_qc_smpl_qc_pass_df = fc_mtx_pass_smpl_gene_qc_df[rna_ids_pass_qc]\n",
    "fc_mtx_gene_qc_smpl_qc_pass_df.to_csv(fc_mtx_smpl_qc_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of genes passing eQTL QC in at least one tissue in GTEx: 39832\n",
      "Number of genes passing gene QC with eQTL in at least one tissue: 39402\n"
     ]
    }
   ],
   "source": [
    "### inactive gene identification\n",
    "\n",
    "# subset to genes passing eQTL QC in at least one GTEx tissue \n",
    "gtex_eqtl_path = Path(gtex_eqtl_dir)\n",
    "normalised_gene_exp_path_list = gtex_eqtl_path.glob(\"*normalized_expression.bed.gz\")\n",
    "# collect genes passing eQTL QC in at least one tissue \n",
    "gene_ids_passed_eqtl_qc_list = []\n",
    "for tissue_normalised_exp_path in normalised_gene_exp_path_list:\n",
    "    gene_ids_passed_eqtl_qc_list += pd.read_csv(tissue_normalised_exp_path, sep=\"\\t\").gene_id.tolist()\n",
    "\n",
    "gene_pass_eqtl_count_df = pd.DataFrame.from_dict(Counter(gene_ids_passed_eqtl_qc_list), orient=\"index\", columns=[\"tissues_pass_eqtl_qc\"])\n",
    "gene_pass_eqtl_count_idx_df = gene_pass_eqtl_count_df.reset_index().rename(columns={\"index\":\"gene_id_vrsn\"})\n",
    "gene_pass_eqtl_count_idx_df[\"gene_id\"] = gene_pass_eqtl_count_idx_df.gene_id_vrsn.str.split(\".\").str[0]\n",
    "\n",
    "# set keeping gene version \n",
    "gene_ids_passed_eqtl_set = set(gene_ids_passed_eqtl_qc_list)\n",
    "# set without gene version \n",
    "gene_ids_passed_eqtl_no_vrsn_set = set([gene_id.split(\".\")[0] for gene_id in gene_ids_passed_eqtl_qc_list])\n",
    "print(f\"Number of genes passing eQTL QC in at least one tissue in GTEx: {len(gene_ids_passed_eqtl_no_vrsn_set)}\")\n",
    "genes_to_keep_pass_eqtl = genes_to_keep.intersection(gene_ids_passed_eqtl_no_vrsn_set)\n",
    "print(f\"Number of genes passing gene QC with eQTL in at least one tissue: {len(genes_to_keep_pass_eqtl)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of genes passing all filters: 29614\n",
      "Genes with TPM > 0.1 in 5.0% samples: 8779\n"
     ]
    }
   ],
   "source": [
    "# write number of tissues gene pass eQTL QC in \n",
    "gtex_pass_eqtl_dir = out_dir_path.joinpath(\"gtex_pass_eqtl_qc\")\n",
    "Path(gtex_pass_eqtl_dir).mkdir(parents=True, exist_ok=True)\n",
    "gtex_pass_eqtl_count_path = gtex_pass_eqtl_dir.joinpath(\"gene_pass_eqtl_count.csv\")\n",
    "gene_pass_eqtl_count_idx_df.to_csv(gtex_pass_eqtl_count_path, sep=\",\", index=False)\n",
    "\n",
    "# subset to protein-coding and lncRNA genes on autosomes \n",
    "genes_pass_filters = []\n",
    "for gtf in pysam.TabixFile(str(gencode_gtf_path)).fetch(parser = pysam.asGTF()):\n",
    "    if gtf.gene_id.split(\".\")[0] in genes_to_keep_pass_eqtl and gtf.feature == \"gene\": \n",
    "        if gtf.contig in AUTOSOMES and gtf.gene_type in GENE_TYPES: \n",
    "            genes_pass_filters.append(gtf.gene_id.split(\".\")[0])\n",
    "if len(genes_pass_filters) != len(set(genes_pass_filters)):\n",
    "    raise ValueError(\"Duplicates in final gene set.\")\n",
    "print(f\"Number of genes passing all filters: {len(genes_pass_filters)}\")\n",
    "# subset expression matrix \n",
    "tpm_smpl_gene_qc_df = tpm_smpl_qc_df[tpm_smpl_qc_df.index.isin(genes_pass_filters)]\n",
    "# compute TPM < 0.1 fraction per gene\n",
    "tpm_fract_df = tpm_smpl_gene_qc_df.reset_index(drop=False)[[\"gene_id\"]].copy()\n",
    "tpm_fract_df[f\"tpm{tpm_fract_cutoff}_fract\"] = ((tpm_smpl_gene_qc_df > tpm_fract_cutoff).sum(axis=1)/num_rna_id_pass_qc).tolist()\n",
    "# subset to genes with TPM > 0.1 in 5% of samples \n",
    "genes_tpm01_fract_grtr95 = set(tpm_fract_df[tpm_fract_df[f\"tpm{tpm_fract_cutoff}_fract\"] < smpl_fraction].gene_id)\n",
    "num_inactive_genes = len(genes_tpm01_fract_grtr95)\n",
    "print(f\"Genes with TPM > {tpm_fract_cutoff} in {smpl_fraction*100}% samples: {num_inactive_genes}\")\n",
    "# subset expression matrix \n",
    "tpm_inactive_df = tpm_smpl_gene_qc_df[tpm_smpl_gene_qc_df.index.isin(genes_tpm01_fract_grtr95)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write TPM > 0.1 fraction per gene to file\n",
    "gene_set_dir = out_dir_path.joinpath(\"gene_sets\")\n",
    "Path(gene_set_dir).mkdir(parents=True, exist_ok=True)\n",
    "gene_tpm_fract_path = gene_set_dir.joinpath(f\"gene_tpm{tpm_fract_cutoff}_fract.csv\")\n",
    "tpm_fract_df.to_csv(gene_tpm_fract_path, index=False)\n",
    "# write inactive gene set to file\n",
    "inactive_gene_path = gene_set_dir.joinpath(f\"inactive_genes_{num_inactive_genes}.txt\")\n",
    "with open(inactive_gene_path, \"w\") as f_out: \n",
    "    for gene_id in genes_tpm01_fract_grtr95: \n",
    "        f_out.write(f\"{gene_id}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write inactive gene expression matrix  \n",
    "inactive_tpm_out_dir = out_dir_path.joinpath(\"tpm_mtx_inactive\")\n",
    "Path(inactive_tpm_out_dir).mkdir(parents=True, exist_ok=True)\n",
    "inactive_tpm_mtx_path = inactive_tpm_out_dir.joinpath(f\"tpm_{num_rna_id_pass_qc}samples_{num_inactive_genes}genes_inactive.tsv\")\n",
    "tpm_inactive_df.to_csv(inactive_tpm_mtx_path, sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of inactive genes with at least one non-zero TPM: 8739\n"
     ]
    }
   ],
   "source": [
    "### Gene expression z-score calculation\n",
    "\n",
    "# remove genes where all samples TPM = 0\n",
    "tpm_inactive_rmv_all_zeroes_df = tpm_inactive_df[(tpm_inactive_df != 0).any(axis=1)]\n",
    "inactive_genes_rmv_all_zeroes = tpm_inactive_rmv_all_zeroes_df.index.unique()\n",
    "num_inactive_genes_rmv_all_zeroes = len(inactive_genes_rmv_all_zeroes)\n",
    "print(f\"Number of inactive genes with at least one non-zero TPM: {num_inactive_genes_rmv_all_zeroes}\")\n",
    "# calculate gene z-scores \n",
    "tpm_inactive_tp_df = tpm_inactive_rmv_all_zeroes_df.transpose()\n",
    "zscore_tp_df = (tpm_inactive_tp_df - tpm_inactive_tp_df.mean())/tpm_inactive_tp_df.std()\n",
    "zscore_df = zscore_tp_df.transpose()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# flatten z-score and TPM expression matrix \n",
    "z_score_flat_df = pd.melt(zscore_df, ignore_index=False).reset_index(drop=False).rename(columns={\"variable\":\"rna_id\", \"value\":\"z-score\"})\n",
    "tpm_inactive_flat_df = pd.melt(tpm_inactive_rmv_all_zeroes_df, ignore_index=False).reset_index(drop=False).rename(columns={\"variable\":\"rna_id\", \"value\":\"TPM\"})\n",
    "# merge z-score and TPM \n",
    "tpm_zscore_inactive_flat_df = pd.merge(tpm_inactive_flat_df, z_score_flat_df, on=[\"gene_id\", \"rna_id\"], how=\"inner\")\n",
    "# write flat z-score TPM matrix \n",
    "z_tpm_flat_out_dir = out_dir_path.joinpath(\"zscore_tpm_flat\")\n",
    "Path(z_tpm_flat_out_dir).mkdir(parents=True, exist_ok=True)\n",
    "tpm_zscore_path_out = z_tpm_flat_out_dir.joinpath(f\"tpm_zscore_{num_rna_id_pass_qc}smpls_{num_inactive_genes_rmv_all_zeroes}genes_tpm{tpm_fract_cutoff}_frac_{smpl_fraction*100}perc_flat.csv\")\n",
    "tpm_zscore_inactive_flat_df.to_csv(tpm_zscore_path_out, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
